{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training an Encrypted Neural Network\n",
    "\n",
    "In this tutorial, we will walk through an example of how we can train a neural network with CrypTen. This is particularly relevant for the <i>Feature Aggregation</i>, <i>Data Labeling</i> and <i>Data Augmentation</i> use cases. We will focus on the usual two-party setting and show how we can train an accurate neural network for digit classification on the MNIST data.\n",
    "\n",
    "For concreteness, this tutorial will step through the <i>Feature Aggregation</i> use cases: Alice and Bob each have part of the features of the data set, and wish to train a neural network on their combined data, while keeping their data private. \n",
    "\n",
    "## Setup\n",
    "As usual, we'll begin by importing and initializing the `crypten` and `torch` libraries.  \n",
    "\n",
    "We will use the MNIST dataset to demonstrate how Alice and Bob can learn without revealing protected information. For reference, the feature size of each example in the MNIST data is `28 x 28`. Let's assume Alice has the first `28 x 20` features and Bob has last `28 x 8` features. One way to think of this split is that Alice has the (roughly) top 2/3rds of each image, while Bob has the bottom 1/3rd of each image. We'll again use our helper script `mnist_utils.py` that downloads the publicly available MNIST data, and splits the data as required.\n",
    "\n",
    "For simplicity, we will restrict our problem to binary classification: we'll simply learn how to distinguish between 0 and non-zero digits. For speed of execution in the notebook, we will only create a dataset of a 100 examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:module 'torchvision.models.mobilenet' has no attribute 'ConvBNReLU'\n"
     ]
    }
   ],
   "source": [
    "import crypten\n",
    "import torch\n",
    "\n",
    "crypten.init()\n",
    "torch.set_num_threads(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run ./mnist_utils.py --option features --reduced 100 --binary"
   ]
  },
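  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The feature split described above can be sketched on random data. This is only an illustration under stated assumptions: the tensor `x_demo` and the names `x_alice_demo` and `x_bob_demo` are hypothetical stand-ins, not the files written by `mnist_utils.py`, and we assume the split is along the last dimension, so that Alice's part has size `28 x 20` and Bob's has size `28 x 8`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical stand-in for 100 MNIST images (random values, not real data)\n",
    "x_demo = torch.rand(100, 28, 28)\n",
    "\n",
    "# Assumed split along the last dimension: first 20 columns vs. last 8 columns\n",
    "x_alice_demo = x_demo[:, :, :20]\n",
    "x_bob_demo = x_demo[:, :, 20:]\n",
    "print(x_alice_demo.size(), x_bob_demo.size())\n",
    "\n",
    "# Concatenating along the same dimension recovers the full feature set\n",
    "x_recombined = torch.cat([x_alice_demo, x_bob_demo], dim=2)\n",
    "assert torch.equal(x_recombined, x_demo)"
   ]
  },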
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll define the network architecture below, and then describe how to train it on encrypted data in the next section. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "#Define an example network\n",
    "class ExampleNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ExampleNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)\n",
    "        self.fc1 = nn.Linear(16 * 12 * 12, 100)\n",
    "        self.fc2 = nn.Linear(100, 2) # For binary classification, final layer needs only 2 outputs\n",
    " \n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = F.relu(out)\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = out.view(-1, 16 * 12 * 12)\n",
    "        out = self.fc1(out)\n",
    "        out = F.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        return out\n",
    "    \n",
    "crypten.common.serial.register_safe_class(ExampleNet)"
   ]
  },
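  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the layer sizes, the sketch below traces the shape of a single `28 x 28` input through a convolution and pooling step with the same settings as `ExampleNet`. The layer `conv_demo` and the tensor `x_shape_demo` are hypothetical names used only for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Stand-alone layer with the same settings as conv1 in ExampleNet\n",
    "conv_demo = nn.Conv2d(1, 16, kernel_size=5, padding=0)\n",
    "x_shape_demo = torch.rand(1, 1, 28, 28)\n",
    "\n",
    "# A 5x5 kernel without padding shrinks each side from 28 to 28 - 5 + 1 = 24\n",
    "out_demo = conv_demo(x_shape_demo)\n",
    "print(out_demo.size())\n",
    "\n",
    "# 2x2 max pooling halves each side to 12, giving 16 * 12 * 12 features for fc1\n",
    "out_demo = F.max_pool2d(F.relu(out_demo), 2)\n",
    "print(out_demo.size())\n",
    "assert out_demo.view(-1, 16 * 12 * 12).size() == (1, 2304)"
   ]
  },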
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Encrypted Training\n",
    "\n",
    "After all the material we've covered in earlier tutorials, we only need to know a few additional items for encrypted training. We'll first discuss how the training loop in CrypTen differs from PyTorch. Then, we'll go through a complete example to illustrate training on encrypted data from end-to-end.\n",
    "\n",
    "### How does CrypTen training differ from PyTorch training?\n",
    "\n",
    "There are two main ways implementing a CrypTen training loop differs from a PyTorch training loop. We'll describe these items first, and then illustrate them with small examples below.\n",
    "\n",
    "<i>(1) Use one-hot encoding</i>: CrypTen training requires all labels to use one-hot encoding. This means that when using standard datasets such as MNIST, we need to modify the labels to use one-hot encoding.\n",
    "\n",
    "<i>(2) Directly update parameters</i>: CrypTen does not use the PyTorch optimizers. Instead, CrypTen implements encrypted SGD by implementing its own `backward` function, followed by directly updating the parameters. As we will see below, using SGD in CrypTen is very similar to using the PyTorch optimizers.\n",
    "\n",
    "We now show some small examples to illustrate these differences. As before, we will assume Alice has the rank 0 process and Bob has the rank 1 process."
   ]
  },
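  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two differences concrete before moving to encrypted data, here is a plaintext PyTorch sketch of the same pattern: one-hot labels, an MSE loss, a backward pass, and a direct parameter update without an optimizer. It is only an analogy for the shape of the CrypTen loop, not CrypTen's implementation; the model `toy_model` and the toy tensors are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Hypothetical toy problem: 8 examples, 4 features, 2 classes\n",
    "x_toy = torch.rand(8, 4)\n",
    "y_toy = torch.tensor([0, 1, 1, 0, 1, 0, 0, 1])\n",
    "\n",
    "# (1) One-hot encoding: row i of the identity matrix encodes class i\n",
    "y_toy_one_hot = torch.eye(2)[y_toy]\n",
    "\n",
    "toy_model = nn.Linear(4, 2)\n",
    "toy_loss = nn.MSELoss()\n",
    "toy_lr = 0.1\n",
    "\n",
    "for step in range(3):\n",
    "    loss_toy = toy_loss(toy_model(x_toy), y_toy_one_hot)\n",
    "    toy_model.zero_grad()\n",
    "    loss_toy.backward()\n",
    "\n",
    "    # (2) Direct update: plain SGD step, p <- p - lr * grad, with no optimizer object\n",
    "    with torch.no_grad():\n",
    "        for p in toy_model.parameters():\n",
    "            p -= toy_lr * p.grad\n",
    "\n",
    "    print('Step {} loss {:.4f}'.format(step, loss_toy.item()))"
   ]
  },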
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define source argument values for Alice and Bob\n",
    "ALICE = 0\n",
    "BOB = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Alice's data \n",
    "data_alice_enc = crypten.load_from_party('/tmp/alice_train.pth', src=ALICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/brianknott/Desktop/CrypTen/crypten/nn/onnx_converter.py:153: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /Users/distiller/project/conda/conda-bld/pytorch_1616554799287/work/torch/csrc/utils/tensor_numpy.cpp:143.)\n",
      "  for t in onnx_model.graph.initializer\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Graph encrypted module"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# We'll now set up the data for our small example below\n",
    "# For illustration purposes, we will create toy data\n",
    "# and encrypt all of it from source ALICE\n",
    "x_small = torch.rand(100, 1, 28, 28)\n",
    "y_small = torch.randint(1, (100,))\n",
    "\n",
    "# Transform labels into one-hot encoding\n",
    "label_eye = torch.eye(2)\n",
    "y_one_hot = label_eye[y_small]\n",
    "\n",
    "# Transform all data to CrypTensors\n",
    "x_train = crypten.cryptensor(x_small, src=ALICE)\n",
    "y_train = crypten.cryptensor(y_one_hot)\n",
    "\n",
    "# Instantiate and encrypt a CrypTen model\n",
    "model_plaintext = ExampleNet()\n",
    "dummy_input = torch.empty(1, 1, 28, 28)\n",
    "model = crypten.nn.from_pytorch(model_plaintext, dummy_input)\n",
    "model.encrypt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0 Loss: 0.5695\n",
      "Epoch: 1 Loss: 0.5417\n"
     ]
    }
   ],
   "source": [
    "# Example: Stochastic Gradient Descent in CrypTen\n",
    "\n",
    "model.train() # Change to training mode\n",
    "loss = crypten.nn.MSELoss() # Choose loss functions\n",
    "\n",
    "# Set parameters: learning rate, num_epochs\n",
    "learning_rate = 0.001\n",
    "num_epochs = 2\n",
    "\n",
    "# Train the model: SGD on encrypted data\n",
    "for i in range(num_epochs):\n",
    "\n",
    "    # forward pass\n",
    "    output = model(x_train)\n",
    "    loss_value = loss(output, y_train)\n",
    "    \n",
    "    # set gradients to zero\n",
    "    model.zero_grad()\n",
    "\n",
    "    # perform backward pass\n",
    "    loss_value.backward()\n",
    "\n",
    "    # update parameters\n",
    "    model.update_parameters(learning_rate) \n",
    "    \n",
    "    # examine the loss after each epoch\n",
    "    print(\"Epoch: {0:d} Loss: {1:.4f}\".format(i, loss_value.get_plain_text()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A Complete Example\n",
    "\n",
    "We now put these pieces together for a complete example of training a network in a multi-party setting. \n",
    "\n",
    "As in Tutorial 3, we'll assume Alice has the rank 0 process, and Bob has the rank 1 process; so we'll load and encrypt Alice's data with `src=0`, and load and encrypt Bob's data with `src=1`. We'll then initialize a plaintext model and convert it to an encrypted model, just as we did in Tutorial 4. We'll finally define our loss function, training parameters, and run SGD on the encrypted data. For the purposes of this tutorial we train on 100 samples; training should complete in ~3 minutes per epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([100, 28, 20])\n",
      "torch.Size([100, 28, 8])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[None, None]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import crypten.mpc as mpc\n",
    "import crypten.communicator as comm\n",
    "\n",
    "# Convert labels to one-hot encoding\n",
    "# Since labels are public in this use case, we will simply use them from loaded torch tensors\n",
    "labels = torch.load('/tmp/train_labels.pth')\n",
    "labels = labels.long()\n",
    "labels_one_hot = label_eye[labels]\n",
    "\n",
    "@mpc.run_multiprocess(world_size=2)\n",
    "def run_encrypted_training():\n",
    "    # Load data:\n",
    "    x_alice_enc = crypten.load_from_party('/tmp/alice_train.pth', src=ALICE)\n",
    "    x_bob_enc = crypten.load_from_party('/tmp/bob_train.pth', src=BOB)\n",
    "    \n",
    "    crypten.print(x_alice_enc.size())\n",
    "    crypten.print(x_bob_enc.size())\n",
    "    \n",
    "    # Combine the feature sets: identical to Tutorial 3\n",
    "    x_combined_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2)\n",
    "    \n",
    "    # Reshape to match the network architecture\n",
    "    x_combined_enc = x_combined_enc.unsqueeze(1)\n",
    "    \n",
    "    \n",
    "    # Commenting out due to intermittent failure in PyTorch codebase\n",
    "    \"\"\"\n",
    "    # Initialize a plaintext model and convert to CrypTen model\n",
    "    pytorch_model = ExampleNet()\n",
    "    model = crypten.nn.from_pytorch(pytorch_model, dummy_input)\n",
    "    model.encrypt()\n",
    "    # Set train mode\n",
    "    model.train()\n",
    "  \n",
    "    # Define a loss function\n",
    "    loss = crypten.nn.MSELoss()\n",
    "\n",
    "    # Define training parameters\n",
    "    learning_rate = 0.001\n",
    "    num_epochs = 2\n",
    "    batch_size = 10\n",
    "    num_batches = x_combined_enc.size(0) // batch_size\n",
    "    \n",
    "    rank = comm.get().get_rank()\n",
    "    for i in range(num_epochs): \n",
    "        crypten.print(f\"Epoch {i} in progress:\")       \n",
    "        \n",
    "        for batch in range(num_batches):\n",
    "            # define the start and end of the training mini-batch\n",
    "            start, end = batch * batch_size, (batch + 1) * batch_size\n",
    "                                    \n",
    "            # construct CrypTensors out of training examples / labels\n",
    "            x_train = x_combined_enc[start:end]\n",
    "            y_batch = labels_one_hot[start:end]\n",
    "            y_train = crypten.cryptensor(y_batch, requires_grad=True)\n",
    "            \n",
    "            # perform forward pass:\n",
    "            output = model(x_train)\n",
    "            loss_value = loss(output, y_train)\n",
    "            \n",
    "            # set gradients to \"zero\" \n",
    "            model.zero_grad()\n",
    "\n",
    "            # perform backward pass: \n",
    "            loss_value.backward()\n",
    "\n",
    "            # update parameters\n",
    "            model.update_parameters(learning_rate)\n",
    "            \n",
    "            # Print progress every batch:\n",
    "            batch_loss = loss_value.get_plain_text()\n",
    "            crypten.print(f\"\\tBatch {(batch + 1)} of {num_batches} Loss {batch_loss.item():.4f}\")\n",
    "    \"\"\"\n",
    "\n",
    "run_encrypted_training()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the average batch loss decreases across the epochs, as we expect during training.\n",
    "\n",
    "This completes our tutorial. Before exiting this tutorial, please clean up the files generated using the following code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "filenames = ['/tmp/alice_train.pth', \n",
    "             '/tmp/bob_train.pth', \n",
    "             '/tmp/alice_test.pth',\n",
    "             '/tmp/bob_test.pth', \n",
    "             '/tmp/train_labels.pth',\n",
    "             '/tmp/test_labels.pth']\n",
    "\n",
    "for fn in filenames:\n",
    "    if os.path.exists(fn): os.remove(fn)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
